import json
from concurrent.futures import ThreadPoolExecutor
import tqdm

from bedrock.client import BedrockClient
from botocore.exceptions import ClientError

from superduper import Model, logging

from typing import List

from dataclasses import dataclass

@dataclass
class BedrockAnthropicChatCompletions(Model):
    """Generate chat completions using Anthropic models hosted on AWS Bedrock.

    Each call to :meth:`predict` sends one user message to
    ``foundation_model`` through the Bedrock runtime ``invoke_model`` API.
    It returns the text of the first content block in the model's reply.
    :meth:`predict_batches` fans :meth:`predict` out over a thread pool.
    """
    # Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-runtime_example_bedrock-runtime_InvokeModel_AnthropicClaude_section.html

    # superduper model signature; 'singleton' matches `predict` taking one `text`.
    signature: str = 'singleton'
    # Bedrock model id passed as `modelId` to `invoke_model` (required).
    foundation_model: str = None
    # AWS credentials/region forwarded to BedrockClient (all required).
    aws_region: str = None
    aws_access_key_id: str = None
    aws_secret_access_key: str = None
    # Declared for the superduper Model interface; not read by this class.
    batch_size: int = 1
    # Not read by this class.
    prompt: str = ''
    # Upper bound on generated tokens, sent as `max_tokens` in the request.
    max_tokens: int = 512
    # Sampling temperature sent with every request; the near-zero default is
    # presumably meant to make output close to deterministic.
    temperature: float = 0.00001

    def __post_init__(self, db, artifacts, example):
        """Validate the required settings and build the Bedrock runtime client.

        Raises:
            ValueError: If any of ``aws_region``, ``aws_access_key_id``,
                ``aws_secret_access_key`` or ``foundation_model`` is empty/None.
        """
        if not self.aws_region or not self.aws_access_key_id or not self.aws_secret_access_key or not self.foundation_model:
            raise ValueError("aws_region, aws_access_key_id, aws_secret_access_key, and foundation_model must be provided and cannot be None.")
        self.bedrock_client = BedrockClient(
            aws_access_key=self.aws_access_key_id,
            aws_secret_key=self.aws_secret_access_key,
            region_name=self.aws_region
        )._get_bedrock_client()
        return super().__post_init__(db, artifacts, example)

    def predict(self, text: str):
        """Generate a chat completion for a single user message.

        Args:
            text: The user message sent to the model.

        Returns:
            str: Text of the first content block in the model's reply.

        Raises:
            Exception: Any error raised by ``invoke_model`` (e.g. botocore's
                ``ClientError``) is logged and then re-raised. Earlier
                versions called ``exit(1)``, which terminated the whole
                process.
        """
        # Request payload in Anthropic's native Messages format on Bedrock.
        native_request = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": text}],
                }
            ],
        }
        request = json.dumps(native_request)

        try:
            response = self.bedrock_client.invoke_model(modelId=self.foundation_model, body=request)
        except Exception as e:
            # Log for visibility, but let the caller decide how to recover
            # instead of killing the interpreter from library code.
            logging.error(f"ERROR: Can't invoke '{self.foundation_model}'. Reason: {e}")
            raise

        # The response body is a stream; read and decode the JSON payload.
        model_response = json.loads(response["body"].read())

        # Only the first content block's text is returned.
        return model_response["content"][0]["text"]

    def predict_batches(self, texts: List, num_threads=10):
        """Run :meth:`predict` over ``texts`` concurrently.

        Args:
            texts: Sized sequence of user messages (``len`` is used for the
                progress bar).
            num_threads: Maximum number of worker threads.

        Returns:
            list: One completion per input, in the same order as ``texts``.
            ``Executor.map`` preserves input order and re-raises the first
            worker exception here.
        """
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            # Builtin `list`, not `typing.List`: the latter cannot be
            # instantiated and raised TypeError on every call.
            results = list(tqdm.tqdm(executor.map(self.predict, texts), total=len(texts)))
        return results


# if __name__ == '__main__':

#     chat_completion_model = "anthropic.claude-3-haiku-20240307-v1:0"  # You can change this to any other model
#     aws_access_key_id = ""
#     aws_secret_access_key = ""
#     aws_region = ""

#     # Example usage of the BedrockAnthropicChatCompletions class.
#     chat_completion = BedrockAnthropicChatCompletions(
#         identifier='chat-completion',
#         foundation_model=chat_completion_model,
#         aws_access_key_id=aws_access_key_id,
#         aws_secret_access_key=aws_secret_access_key,
#         aws_region=aws_region
#     )

#     # Example prompt for the chat completion model.
#     prompt = "What is the meaning of life?"

#     # Generate a chat completion based on the prompt.
#     result = chat_completion.predict(prompt)

#     print(result)
